#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <sys/types.h>
#include <assert.h>
#include <memory.h>
#include <sys/shm.h>

#include "buddy.h"
#include "configure.h"
#include "mem_pool.h"
#include "config.h"
#include "dbg.h"

/*
 * Build-time switch. When non-zero, mem_pool_alloc_pool() and
 * mem_pool_free_pool() obtain/release their buffers through
 * huge_tls_malloc()/huge_tls_free() instead of the SysV shared-memory /
 * valloc code in those functions.
 */
#define LICH_TLS_HUGE_PAGE      1

/* Divisor used by mem_pool_page_size() to derive a page size from a pool size. */
#define MEM_POOL_DEFAULT_PAGE_COUNT     2048

/*
 * Values for the pool_type argument of mem_pool_alloc_pool().
 * Only MEM_POOL_TYPE_HUGE is tested for there; any other value selects the
 * valloc() path.
 */
#define MEM_POOL_TYPE_MALLOC            0
#define MEM_POOL_TYPE_HUGE              1

#ifndef round_up
/*
 * Round x up to the next multiple of y (x itself when it is already a
 * multiple). Every argument use is parenthesized so that expression
 * arguments such as round_up(n, a + b) expand correctly.
 * NOTE: both arguments are evaluated more than once; do not pass
 * expressions with side effects. y must be non-zero.
 */
#define round_up(x, y)                  (((x) % (y) == 0) ? (x) : (((x) + (y)) / (y) * (y)))
#endif

/*
 * Per-thread table of memory pools, one slot per memory type index.
 * Slots start out NULL, are filled on first use by mem_pool_get() and are
 * reset to NULL by mem_pool_put().
 */
static __thread mem_pool_t *g_mem_pools[MAX_MEMORY_TYPE_INDEX] = {0};

/*
 * Return the smallest power of two that is >= val.
 *
 * val == 0 yields 1. When the result is not representable in 32 bits
 * (val > 2^31) the function returns 0; previously the shift wrapped to 0
 * and the loop never terminated.
 */
static inline uint32_t round_up_to_pow_of_2(uint32_t val)
{
        uint32_t sz = 1;

        if (val > UINT32_C(0x80000000))
                return 0;

        while (sz < val)
                sz <<= 1;

        return sz;
}

/*
 * Return the largest power of two that is <= val, or 0 when val is 0.
 *
 * The loop compares against val / 2 so that sz never has to exceed val;
 * the previous form shifted sz past 2^31 to 0 for val >= 2^31 and then
 * looped forever.
 */
static inline uint32_t round_down_to_power_of_2(uint32_t val)
{
        uint32_t sz = 1;

        if (val == 0)
                return 0;

        while (sz <= val / 2)
                sz <<= 1;

        return sz;
}

/*
 * Derive a page size from a pool size: divide the pool into
 * MEM_POOL_DEFAULT_PAGE_COUNT parts and round that quotient up to a power
 * of two.
 */
static inline uint32_t mem_pool_page_size(uint32_t total_size)
{
        return round_up_to_pow_of_2(total_size / MEM_POOL_DEFAULT_PAGE_COUNT);
}

/*
 * Allocator pair defined outside this file. mem_pool_alloc_pool() and
 * mem_pool_free_pool() call them when LICH_TLS_HUGE_PAGE is enabled.
 * NOTE(review): presumably a thread-local huge-page allocator, judging by
 * the names only -- confirm against the definition.
 */
void *huge_tls_malloc(size_t sz);
void huge_tls_free(void *ptr);

/*
 * Release a buffer obtained from mem_pool_alloc_pool().
 *
 * @pool_buf  buffer to release
 * @shmid     shm id reported by mem_pool_alloc_pool(); >= 0 means the
 *            buffer is an attached SysV segment, otherwise it came from
 *            the C allocator. Ignored when LICH_TLS_HUGE_PAGE is enabled.
 *
 * The two back ends are selected with #if/#else so that exactly one of
 * them is compiled (no unreachable code, and no "return <void expr>;" in
 * a void function).
 */
static inline void mem_pool_free_pool(uint8_t *pool_buf, int shmid)
{
#if LICH_TLS_HUGE_PAGE
        (void)shmid;

        DINFO("free big pool, ptr:%p\r\n", (void *)pool_buf);
        huge_tls_free(pool_buf);
#else
        if (shmid >= 0) {
                /* Segment was already marked IPC_RMID at allocation time;
                 * detaching is all that is left to do. */
                if (shmdt(pool_buf) != 0)
                        DERROR("shmem detach failure (errno=%d %m)", errno);
        } else {
                free(pool_buf);
        }
#endif
}

/*
 * Allocate one backing buffer for a pool.
 *
 * @pool_size  number of bytes to allocate
 * @pool_type  MEM_POOL_TYPE_HUGE to use a SysV huge-page segment, anything
 *             else to use valloc(). Ignored when LICH_TLS_HUGE_PAGE is
 *             enabled.
 * @shmid      out: id of the shm segment when one was attached, otherwise
 *             -1. It is always written, including on failure, so callers
 *             never read an indeterminate value.
 *
 * Returns the buffer, or NULL on failure.
 */
static inline void *mem_pool_alloc_pool(uint32_t pool_size, int pool_type, int *shmid)
{
#if LICH_TLS_HUGE_PAGE
        void *ptr;

        (void)pool_type;
        *shmid = -1;

        ptr = huge_tls_malloc(pool_size);
        DINFO("allocate big pool, size: %u, ptr:%p\r\n", pool_size, ptr);

        return ptr;
#else
        int shmemid;
        void *buf;

        *shmid = -1;

        if (pool_type != MEM_POOL_TYPE_HUGE)
                return valloc(pool_size);

        /* allocate memory */
        shmemid = shmget(IPC_PRIVATE, pool_size, SHM_HUGETLB | IPC_CREAT | SHM_R | SHM_W);
        if (shmemid < 0) {
                DERROR("shmget rdma pool sz:%u failed\n", pool_size);

                return NULL;
        }

        /* get pointer to allocated memory */
        buf = shmat(shmemid, NULL, 0);
        if (buf == (void *)-1) {
                DERROR("Shared memory attach failure (errno=%d %m)", errno);
                shmctl(shmemid, IPC_RMID, NULL);

                return NULL;
        }

        /* Mark the segment 'to be destroyed' once the process detaches from
           it. This releases the huge-page resources even if the process is
           killed ungracefully. According to the shmctl man page it is
           unlikely to fail here. */
        if (shmctl(shmemid, IPC_RMID, NULL))
                DERROR("Shared memory contrl mark 'to be destroyed' failed (errno=%d %m)", errno);

        DBUG("Allocated huge page sz:%u\n", pool_size);
        *shmid = shmemid;

        return buf;
#endif
}

/*
 * Initialise a pool descriptor and allocate its backing storage.
 *
 * @pool   descriptor to fill in (caller-owned; not freed here)
 * @size   requested pool size in bytes. The usable page count is
 *         size / MEM_POOL_PAGE_SIZE rounded down to a power of two.
 * @flags  stored verbatim in pool->flags
 *
 * Two buffers are allocated: one for the buddy bookkeeping (zero-filled)
 * and one of @size bytes for the pages handed out to callers.
 *
 * Returns 0 on success (the value of buddy_init()), EINVAL when @size is
 * unusable, ENOMEM when a buffer could not be allocated, or the non-zero
 * value reported by buddy_init(). On any failure no buffer is left
 * allocated and the pointer fields of @pool are NULL.
 */
int mem_pool_create(mem_pool_t *pool, uint64_t size, int flags)
{
        uint32_t buddy_size;
        uint32_t page_count;
        int ret;

        /* Put the descriptor in a known state first so the error path never
         * inspects uninitialised fields. */
        pool->flags = flags;
        pool->next = NULL;
        pool->base_addr = NULL;
        pool->buddy_addr = NULL;
        pool->mem_pool_addr = NULL;
        pool->shmid[0] = -1;
        pool->shmid[1] = -1;

        /* mem_pool_alloc_pool() takes a 32-bit size; anything larger would be
         * silently truncated and leave the pool smaller than it believes. */
        if (size > UINT32_MAX)
                return EINVAL;

        page_count = (uint32_t)(size / MEM_POOL_PAGE_SIZE);
        page_count = round_down_to_power_of_2(page_count); //round down to power of 2.
        if (page_count == 0)
                return EINVAL;

        buddy_size = BUDDY_SIZE(page_count);
        buddy_size = round_up(buddy_size, MEM_POOL_PAGE_SIZE);
        buddy_size = round_up_to_pow_of_2(buddy_size);

        printf("buddy_size = %u\r\n", buddy_size);

        pool->buddy_addr = pool->base_addr = mem_pool_alloc_pool(buddy_size, MEM_POOL_TYPE_HUGE, &pool->shmid[0]);
        if (!pool->buddy_addr)
                goto err_nomem;

        pool->mem_pool_addr = mem_pool_alloc_pool((uint32_t)size, MEM_POOL_TYPE_HUGE, &pool->shmid[1]);
        if (!pool->mem_pool_addr)
                goto err_nomem;

        memset(pool->buddy_addr, 0, buddy_size);

        pool->page_size = MEM_POOL_PAGE_SIZE;
        pool->page_count = page_count;

        /* Callers treat any non-zero return as failure and only free the
         * descriptor, so the buffers must be released here in that case. */
        ret = buddy_init((buddy_t *)pool->buddy_addr, pool->page_count);
        if (ret)
                goto err_free;

        return 0;

err_nomem:
        /* Kept from the original: debug builds stop on allocation failure. */
        assert(0);
        ret = ENOMEM;
err_free:
        if (pool->buddy_addr)
                mem_pool_free_pool(pool->buddy_addr, pool->shmid[0]);
        if (pool->mem_pool_addr)
                mem_pool_free_pool(pool->mem_pool_addr, pool->shmid[1]);

        pool->base_addr = NULL;
        pool->buddy_addr = NULL;
        pool->mem_pool_addr = NULL;

        return ret;
}

/* Free both backing buffers of one pool and clear the stale pointers. */
static void mem_pool_release_buffers(mem_pool_t *pool)
{
        if (pool->buddy_addr)
                mem_pool_free_pool(pool->buddy_addr, pool->shmid[0]);
        if (pool->mem_pool_addr)
                mem_pool_free_pool(pool->mem_pool_addr, pool->shmid[1]);

        pool->base_addr = NULL;
        pool->buddy_addr = NULL;
        pool->mem_pool_addr = NULL;
}

/*
 * Tear down a pool created with mem_pool_create().
 *
 * Releases the backing buffers of @pool and of every extension pool linked
 * through pool->next, and frees the extension descriptors themselves
 * (mem_pool_alloc() allocates them with malloc()). The descriptor @pool is
 * owned by the caller and is NOT freed. Afterwards the pointer fields are
 * NULL and pool->next is NULL. A NULL @pool is a no-op.
 *
 * NOTE(review): this assumes nothing outside this file links
 * non-malloc'ed descriptors into pool->next -- confirm.
 */
void mem_pool_destory(mem_pool_t *pool)
{
        mem_pool_t *ext;
        mem_pool_t *following;

        if (!pool)
                return;

        for (ext = pool->next; ext; ext = following) {
                following = ext->next;
                mem_pool_release_buffers(ext);
                free(ext);
        }
        pool->next = NULL;

        mem_pool_release_buffers(pool);
}

/*
 * Allocate @mem_size bytes from a single pool (the chain is not followed).
 *
 * The request is converted to whole pages, rounded up to a power-of-two
 * page count, and handed to buddy_alloc().
 *
 * Returns a pointer into pool->mem_pool_addr, or NULL when buddy_alloc()
 * reports a negative index.
 */
void * __mem_pool_alloc(mem_pool_t *pool, uint32_t mem_size)
{
        int index;
        int count;
        buddy_t *buddy = (buddy_t *)pool->buddy_addr;

        /* ceil(mem_size / page_size) */
        count = mem_size / pool->page_size;
        if (mem_size % pool->page_size)
                count++;

        count = round_up_to_pow_of_2(count);

        index = buddy_alloc(buddy, count);
        if (index < 0)
                return NULL;

        assert(index >= 0 && index + count <= pool->page_count);

        /* Byte offset in 64 bits on a byte pointer, matching the range maths
         * in mem_pool_free(); avoids 32-bit wrap and void* arithmetic. */
        return (uint8_t *)pool->mem_pool_addr + (uint64_t)index * pool->page_size;
}

/*
 * Allocate @mem_size bytes from a pool chain, growing the chain on demand.
 *
 * Each pool in the chain starting at @pool is tried in order. When none
 * can satisfy the request, a new pool is created and appended. Its size is
 * twice that of the last pool, doubled further until it can hold the
 * request.
 *
 * Returns the allocated block, or NULL when @pool is NULL, when the
 * required extension would exceed the 32-bit size limit of the backing
 * allocator, or when the extension could not be created.
 */
void * mem_pool_alloc(mem_pool_t *pool, uint32_t mem_size)
{
        int ret;
        mem_pool_t *last = NULL;
        mem_pool_t *ext;
        uint64_t new_size;
        void *ptr;

        if (unlikely(!pool))
                return NULL;

        for (; pool; pool = pool->next) {
                ptr = __mem_pool_alloc(pool, mem_size);
                if (likely(ptr))
                        return ptr;

                last = pool;
        }

        new_size = (uint64_t)last->page_size * last->page_count * 2;
        if (unlikely(new_size == 0))
                return NULL;

        /* A pool of 2x the previous size may still be too small for a large
         * request; keep doubling so the allocation below cannot fail on size. */
        while (new_size < mem_size)
                new_size *= 2;

        /* Backing buffers are allocated through a 32-bit size parameter. */
        if (unlikely(new_size > UINT32_MAX)) {
                DERROR("memory pool extend failed, pool size %llu too large\r\n", (unsigned long long)new_size);
                return NULL;
        }

        DINFO("memory pool extend new, pool size: %llu\r\n", (unsigned long long)new_size);

        ext = malloc(sizeof(*ext));
        if (unlikely(!ext))
                return NULL;

        ret = mem_pool_create(ext, new_size, last->flags);
        if (unlikely(ret)) {
                free(ext);
                return NULL;
        }

        /* Link the extension only once it is fully initialised, so the chain
         * never points at freed or half-built descriptors. */
        last->next = ext;

        ptr = __mem_pool_alloc(ext, mem_size);
        YASSERT(ptr);

        return ptr;
}

/*
 * Return a block obtained from mem_pool_alloc() to its pool.
 *
 * @pool      head of the pool chain the block was allocated from
 * @mem_addr  block address; NULL is a no-op
 *
 * The chain is searched for the pool whose page area contains @mem_addr.
 * Every pool, including the last one, is range-checked. An address that
 * belongs to no pool trips the assert in debug builds and is ignored
 * otherwise, rather than being turned into a bogus page index.
 */
void mem_pool_free(mem_pool_t *pool, void *mem_addr)
{
        const uint8_t *addr = mem_addr;
        const uint8_t *base = NULL;
        uint64_t span;
        int index;

        if (!mem_addr)
                return;

        for (; pool; pool = pool->next) {
                base = (const uint8_t *)pool->mem_pool_addr;
                span = (uint64_t)pool->page_size * pool->page_count;

                if (addr >= base && addr < base + span)
                        break;
        }

        assert(pool);
        if (!pool)
                return;

        index = (addr - base) / pool->page_size;

        buddy_free((buddy_t *)pool->buddy_addr, index);
}


mem_pool_t * mem_pool_get(int index)
{
        int ret;
        uint32_t size = DEFAULT_MEM_POOL_SIZE(index);

        if(g_mem_pools[index])
                return g_mem_pools[index];

        g_mem_pools[index] = malloc(sizeof(mem_pool_t));
        if(!g_mem_pools[index])
                return NULL;

        ret = mem_pool_create(g_mem_pools[index], (uint64_t)size, 0);
        if(ret) {
                free (g_mem_pools[index]);
                g_mem_pools[index] = NULL;
                
                return NULL;
        }

        return g_mem_pools[index];
}

void mem_pool_put(int index)
{
        mem_pool_destory(g_mem_pools[index]);
        free(g_mem_pools[index]);
        g_mem_pools[index] = NULL;
}